{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `metadsl` + `uarray`\n",
    "\n",
    "First, let's installed the latest uarray:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/Quansight-Labs/uarray.git\n",
      "  Cloning https://github.com/Quansight-Labs/uarray.git to /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-req-build-vv251_40\n",
      "Building wheels for collected packages: uarray\n",
      "  Building wheel for uarray (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Stored in directory: /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-ephem-wheel-cache-6m7r4pvj/wheels/3d/1a/bf/60f787ed8f0ac071de28d869eb644793856923ac4688c14544\n",
      "Successfully built uarray\n",
      "Installing collected packages: uarray\n",
      "  Found existing installation: uarray 0.4+168.g345664c\n",
      "    Uninstalling uarray-0.4+168.g345664c:\n",
      "      Successfully uninstalled uarray-0.4+168.g345664c\n",
      "Successfully installed uarray-0.4+228.g09d6ad2\n"
     ]
    }
   ],
   "source": [
    "!pip install numpy\n",
    "!pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp37-cp37m-linux_x86_64.whl\n",
    "!pip install torchvision\n",
    "!pip install -U git+https://github.com/Quansight-Labs/uarray.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`uarray`  and `unumpy` provide ways to execute on different backends like Torch, NumPy, XND.\n",
    "\n",
    "`metadsl` provides a way to optimize your computation befor executing.\n",
    "\n",
    "Here, we will show how we can integrate them to allow users to build up an expression with `unumpy`, optimize it with `metadsl`, and then execute it with `unumpy`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uarray\n",
    "import unumpy\n",
    "import typing\n",
    "import unumpy.multimethods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import metadsl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's map certain unumpy methods to metadsl functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "metadsl_backend = uarray.Backend()\n",
    "uarray.register_backend(metadsl_backend)\n",
    "\n",
    "class Array(metadsl.Instance):\n",
    "    pass\n",
    "\n",
    "\n",
    "@metadsl.call(lambda start, stop, stride: Array)\n",
    "def arange(start: int, stop: int, stride: int) -> Array:\n",
    "    ...\n",
    "\n",
    "@metadsl.call(lambda shape: Array)\n",
    "def zeros(shape: typing.Tuple[int]) -> Array:\n",
    "    ...\n",
    "\n",
    "@metadsl.call(lambda left, right: Array)\n",
    "def add(left: Array, right: Array) -> Array:\n",
    "    ...\n",
    "\n",
    "@metadsl.call(lambda a: Array)\n",
    "def sum(a: Array) -> Array:\n",
    "    ...\n",
    "\n",
    "@metadsl.call(lambda: Array)\n",
    "def zero() -> Array:\n",
    "    ...\n",
    "    \n",
    "METADSL_TO_UNUMPY = {\n",
    "    arange: unumpy.arange,\n",
    "    zeros: unumpy.zeros,\n",
    "    add: unumpy.add,\n",
    "    sum: unumpy.sum \n",
    "}\n",
    "for m, u in METADSL_TO_UNUMPY.items():\n",
    "    uarray.register_implementation(u, metadsl_backend)(m)\n",
    "unumpy.ndarray.register_convertor(metadsl_backend, lambda i: i)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use them to build up some expression in metadsl:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with uarray.set_backend(metadsl_backend):\n",
    "    left = unumpy.arange(0, 10, 2)\n",
    "    right = unumpy.sum(unumpy.zeros(10))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'arange(0, 10, 2)'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(left)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sum(zeros(10))'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(right)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's simplify summing zeros with just a zero:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "simplifications = metadsl.RulesRepeatFold()\n",
    "simplify = metadsl.RuleApplier(simplifications)\n",
    "\n",
    "@simplifications.append\n",
    "@metadsl.pure_rule(None)\n",
    "def _sum_zeros_zero(shape: typing.Tuple[int]):\n",
    "    return (\n",
    "        sum(zeros(shape)),\n",
    "        zero(),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum(zeros(10))\n",
      "zero()\n"
     ]
    }
   ],
   "source": [
    "print(str(right))\n",
    "print(str(simplify(right)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(a):\n",
    "    if not isinstance(a, Array):\n",
    "        return a\n",
    "    return METADSL_TO_UNUMPY[a._call.function](*map(execute, a._call.args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "from unumpy.numpy_backend import NumpyBackend\n",
    "\n",
    "with uarray.set_backend(NumpyBackend):\n",
    "    print(execute(right))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "from unumpy.torch_backend import TorchBackend\n",
    "\n",
    "with uarray.set_backend(TorchBackend):\n",
    "    print(execute(right))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
